{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 目标巡线 - 训练模型\n",
    "\n",
    "在本笔记本中，我们将训练神经网络以获取输入图像，并输出与目标相对应的一组x，y值。\n",
    "\n",
    "我们将使用PyTorch深度学习框架来训练ResNet18神经网络体系结构模型，以供道路跟随者应用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import glob\n",
    "import PIL.Image\n",
    "import os\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 下载并提取数据\n",
    "\n",
    "\n",
    "开始之前，您应该上载在机器人上的data_collection.ipynb笔记本中创建的road_following_ <Date＆Time> .zip文件。\n",
    "\n",
    ">如果要在收集数据的JetBot上进行培训，则可以跳过此操作！\n",
    "\n",
    "然后，您应该通过调用以下命令来提取该数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -q road_following.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "您应该看到一个名为``dataset_all``的文件夹出现在文件浏览器中。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 建立数据集实例\n",
    "\n",
    "在这里，我们创建了一个自定义的Torch.utils.data.Dataset实现，该实现实现了__len__和__getitem__函数。 这节课\n",
    "负责加载图像并从图像文件名中解析x，y值。 由于我们实现了“ torch.utils.data.Dataset”类，\n",
    "我们可以使用所有的火炬数据实用程序：）\n",
    "\n",
    "我们将一些转换（例如颜色抖动）硬编码到我们的数据集中。 我们将随机的水平翻转设为可选（如果您想遵循非对称路径，例如道路\n",
    "我们需要“保持正确”的位置）。 如果您的机器人是否遵循某些约定并不重要，则可以启用翻转功能以扩充数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_x(path):\n",
    "    \"\"\"Gets the x value from the image filename\"\"\"\n",
    "    return (float(int(path[3:6])) - 50.0) / 50.0\n",
    "\n",
    "def get_y(path):\n",
    "    \"\"\"Gets the y value from the image filename\"\"\"\n",
    "    return (float(int(path[7:10])) - 50.0) / 50.0\n",
    "\n",
    "class XYDataset(torch.utils.data.Dataset):\n",
    "    \n",
    "    def __init__(self, directory, random_hflips=False):\n",
    "        self.directory = directory\n",
    "        self.random_hflips = random_hflips\n",
    "        self.image_paths = glob.glob(os.path.join(self.directory, '*.jpg'))\n",
    "        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.image_paths)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        image_path = self.image_paths[idx]\n",
    "        \n",
    "        image = PIL.Image.open(image_path)\n",
    "        x = float(get_x(os.path.basename(image_path)))\n",
    "        y = float(get_y(os.path.basename(image_path)))\n",
    "        \n",
    "        if float(np.random.rand(1)) > 0.5:\n",
    "            image = transforms.functional.hflip(image)\n",
    "            x = -x\n",
    "        \n",
    "        image = self.color_jitter(image)\n",
    "        image = transforms.functional.resize(image, (224, 224))\n",
    "        image = transforms.functional.to_tensor(image)\n",
    "        image = image.numpy()[::-1].copy()\n",
    "        image = torch.from_numpy(image)\n",
    "        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        \n",
    "        return image, torch.tensor([x, y]).float()\n",
    "    \n",
    "dataset = XYDataset('dataset_xy', random_hflips=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集分割为训练集和测试集\n",
    "读取数据集后，我们将数据集分为训练集和测试集。 在此示例中，我们拆分训练并测试90％-10％。 测试集将用于验证我们训练的模型的准确性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_percent = 0.1\n",
    "num_test = int(test_percent * len(dataset))\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - num_test, num_test])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建数据加载器以批量加载数据\n",
    "\n",
    "我们使用DataLoader类来批量加载数据，随机播放数据并允许使用多个子进程。 在此示例中，我们使用64的批处理大小。批处理大小将基于GPU可用的内存，这会影响模型的准确性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义神经网络 \n",
    "\n",
    "我们使用PyTorch TorchVision上可用的ResNet-18模型。\n",
    "\n",
    "在称为转移学习的过程中，我们可以将预训练的模型（在数百万张图像上进行训练）重新用于可能需要更少数据的新任务。\n",
    "\n",
    "\n",
    "有关ResNet-18的更多详细信息：https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n",
    "\n",
    "有关转学的更多详细信息：https://www.youtube.com/watch?v=yofjFQddwHE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnet18(pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ResNet模型已完全连接（fc）最终层，其中512为输入特征（in_features），我们将进行回归训练，因此输出特征（out_features）为1\n",
    "\n",
    "最后，我们转移模型以在GPU上执行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fc = torch.nn.Linear(512, 2)\n",
    "device = torch.device('cuda')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练回归:\n",
    "\n",
    "如果损失减少，我们将训练50个时期，并保存最佳模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 70\n",
    "BEST_MODEL_PATH = 'best_steering_model_xy.pth'\n",
    "best_loss = 1e9\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    for images, labels in iter(train_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(images)\n",
    "        loss = F.mse_loss(outputs, labels)\n",
    "        train_loss += float(loss)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    train_loss /= len(train_loader)\n",
    "    \n",
    "    model.eval()\n",
    "    test_loss = 0.0\n",
    "    for images, labels in iter(test_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        loss = F.mse_loss(outputs, labels)\n",
    "        test_loss += float(loss)\n",
    "    test_loss /= len(test_loader)\n",
    "    \n",
    "    print('%f, %f' % (train_loss, test_loss))\n",
    "    if test_loss < best_loss:\n",
    "        torch.save(model.state_dict(), BEST_MODEL_PATH)\n",
    "        best_loss = test_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练完模型后，它将生成“ best_steering_model_xy.pth”文件，可用于在实时演示中进行测试。\n",
    "\n",
    "如果您在JetBot以外的其他机器上进行培训，则需要将其上载到JetBot的“ road_following”示例文件夹中。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
